{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "u4UNB9a6jCkU"
   },
   "source": [
    "# L5-B: Packing 2-bit Weights\n",
    "\n",
    "In this lesson, you will learn how to store low precision weights through a technique called \"packing\"."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ma4RwtOYurm-"
   },
   "source": [
    "## Packing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note:** Younes will explain the code below and walk through each iteration step. You can go through the comprehensive explanation in the markdown below after first watching Younes's explanation.\n",
    "\n",
    "```Python\n",
    "# Example Tensor: [1, 0, 3, 2]\n",
    "    # 1 0 3 2 - 01 00 11 10\n",
    "\n",
    "    # Starting point of packed int8 Tensor\n",
    "    # [0000 0000]\n",
    "    \n",
    "    ##### First Iteration Start:\n",
    "    # packed int8 Tensor State: [0000 0000]\n",
    "    # 1 = 0000 0001\n",
    "    # 0000 0001\n",
    "    # No left shifts in the First Iteration\n",
    "    # After bit-wise OR operation between 0000 0000 and 0000 0001:\n",
    "    # packed int8 Tensor State: 0000 0001\n",
    "    ##### First Iteration End\n",
    "\n",
    "    ##### Second Iteration Start:\n",
    "    # packed int8 Tensor State: [0000 0001]\n",
    "    # 0 = 0000 0000\n",
    "    # 0000 0000\n",
    "    # 2 left shifts:\n",
    "    # [0000 0000] (1 shift)-> 0000 0000 (2 shift)-> 0000 0000\n",
    "    # After bit-wise OR operation between 0000 0001 and 0000 0000:\n",
    "    # packed int8 Tensor State: 0000 0001\n",
    "    ##### Second Iteration End\n",
    "\n",
    "    ##### Third Iteration Start:\n",
    "    # packed int8 Tensor State: [0000 0001]\n",
    "    # 3 = 0000 0011\n",
    "    # 0000 0011\n",
    "    # 4 left shifts:\n",
    "    # [0000 0011] (1 shift)-> 0000 0110 (2 shift)-> 0000 1100\n",
    "    # 0000 1100 (3 shift)-> 0001 1000 (4 shift)-> 0011 0000\n",
    "    # After bit-wise OR operation between 0000 0001 and 0011 0000:\n",
    "    # packed int8 Tensor State: 0011 0001\n",
    "    ##### Third Iteration End\n",
    "\n",
    "    ##### Fourth Iteration Start:\n",
    "    # packed int8 Tensor State: [0011 0001]\n",
    "    # 2 = 0000 0010\n",
    "    # 0000 0010\n",
    "    # 6 left shifts:\n",
    "    # [0000 0010] (1 shift)-> 0000 0100 (2 shift)-> 0000 1000\n",
    "    # 0000 1000 (3 shift)-> 0001 0000 (4 shift)-> 0010 0000\n",
    "    # 0010 0000 (5 shift)-> 0100 0000 (6 shift)-> 1000 0000\n",
    "    # After bit-wise OR operation between 0011 0001 and 1000 0000:\n",
    "    # packed int8 Tensor State: 1011 0001\n",
    "    ##### Fourth Iteration End\n",
    "    \n",
    "    # Final packed int8 Tensor State: [1011 0001]\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 218,
     "status": "ok",
     "timestamp": 1705676413805,
     "user": {
      "displayName": "Younes Belkada",
      "userId": "15414910276690549281"
     },
     "user_tz": 0
    },
    "height": 539,
    "id": "ErbmOljngvnC"
   },
   "outputs": [],
   "source": [
    "def pack_weights(uint8tensor, bits):\n",
    "    \"\"\"Pack `bits`-wide values into bytes, `8 // bits` values per byte.\n",
    "\n",
    "    Within each group, the value at position `j` is shifted left by\n",
    "    `bits * j` and OR-ed into the output byte, so the first value of a\n",
    "    group occupies the lowest bits. For example, `[1, 0, 3, 2]` with\n",
    "    `bits=2` packs to `0b10110001` (177).\n",
    "\n",
    "    Args:\n",
    "        uint8tensor: Tensor of values indexed along its first dimension;\n",
    "            every value must lie in `[0, 2**bits - 1]`.\n",
    "        bits: Width of each value in bits; must be 1, 2, 4 or 8.\n",
    "\n",
    "    Returns:\n",
    "        A `torch.uint8` tensor holding `uint8tensor.shape[0] * bits // 8`\n",
    "        packed bytes.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `bits` does not divide 8, if the input length is\n",
    "            not a multiple of `8 // bits`, or if a value needs more than\n",
    "            `bits` bits.\n",
    "    \"\"\"\n",
    "    if bits not in (1, 2, 4, 8):\n",
    "        # Any other width leaves unused bits in each byte, so the loop\n",
    "        # below would silently skip inputs (or divide by zero for bits=0).\n",
    "        raise ValueError(f\"bits must be one of 1, 2, 4 or 8 - got {bits}\")\n",
    "\n",
    "    num_steps = 8 // bits  # number of values stored in one packed byte\n",
    "    num_inputs = uint8tensor.shape[0]\n",
    "\n",
    "    if num_inputs % num_steps != 0:\n",
    "        raise ValueError(\n",
    "            f\"The input shape needs to be a multiple of {num_steps} \"\n",
    "            f\"- got {num_inputs}\"\n",
    "        )\n",
    "\n",
    "    # A value wider than `bits` would spill into the neighbouring field\n",
    "    # once shifted, corrupting the packed byte without any error.\n",
    "    max_value = (1 << bits) - 1\n",
    "    if uint8tensor.numel() > 0 and (\n",
    "        int(uint8tensor.min()) < 0 or int(uint8tensor.max()) > max_value\n",
    "    ):\n",
    "        raise ValueError(\n",
    "            f\"Every value must fit in {bits} bits (0..{max_value})\"\n",
    "        )\n",
    "\n",
    "    num_values = num_inputs // num_steps\n",
    "    packed_tensor = torch.zeros(num_values, dtype=torch.uint8)\n",
    "\n",
    "    # Worked example for [1, 0, 3, 2] with bits=2 (one output byte):\n",
    "    #   j=0: 01 << 0 -> 0000 0001\n",
    "    #   j=1: 00 << 2 -> 0000 0001\n",
    "    #   j=2: 11 << 4 -> 0011 0001\n",
    "    #   j=3: 10 << 6 -> 1011 0001\n",
    "    for i in range(num_values):\n",
    "        for j in range(num_steps):\n",
    "            packed_tensor[i] |= uint8tensor[i * num_steps + j] << (bits * j)\n",
    "    return packed_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Four values that each fit in 2 bits: 01 00 11 10.\n",
    "unpacked_tensor = torch.tensor([1, 0, 3, 2], dtype=torch.uint8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 214,
     "status": "ok",
     "timestamp": 1705676415692,
     "user": {
      "displayName": "Younes Belkada",
      "userId": "15414910276690549281"
     },
     "user_tz": 0
    },
    "height": 29,
    "id": "zcLwmPr1Fdrg",
    "outputId": "f50ccb6b-8ec4-47dc-91c1-bc70db8aafa8"
   },
   "outputs": [],
   "source": [
    "# [1, 0, 3, 2] at 2 bits each -> one byte, 0b10110001 == 177.\n",
    "pack_weights(uint8tensor=unpacked_tensor, bits=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Eight values that each fit in 2 bits.\n",
    "unpacked_tensor = torch.tensor([1, 0, 3, 2, 3, 3, 3, 3], dtype=torch.uint8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "# Eight 2-bit values -> two bytes: 0b10110001 (177) and 0b11111111 (255).\n",
    "pack_weights(uint8tensor=unpacked_tensor, bits=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "authorship_tag": "ABX9TyPkr3WbwaCNWYoK5KkH4mJ5",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
